如何将PyTorch模型导出到文件(Python)并使用TorchScript加载它(libtorsC++)?

您所在的位置:网站首页 windows 正在加载文件 如何将PyTorch模型导出到文件(Python)并使用TorchScript加载它(libtorsC++)?

如何将PyTorch模型导出到文件(Python)并使用TorchScript加载它(libtorsC++)?

2023-04-26 12:36| 来源: 网络整理| 查看: 265

我正在为PyTorch数据的(反)序列化而奋斗。在使用PyTorch (使用GPU)进行培训之后,我希望将模型保存到PT(H)文件中。接下来,我想在C++上下文中加载这个序列化模型(使用libtorch)。目前,我只是在试验基本的导出/导入功能,以掌握它的诀窍。

代码如下所示。我得到了以下错误:

Error loading model Unrecognized data format Exception raised from load at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\jit\serialization\import.cpp:449 (most recent call first): 00007FFBB1FFDA2200007FFBB1FFD9C0 c10.dll!c10::Error::Error [ @ ] 00007FFBB1FFD43E00007FFBB1FFD3F0 c10.dll!c10::detail::torchCheckFail [ @ ] 00007FFB4B87B54700007FFB4B87B4E0 torch_cpu.dll!torch::jit::load [ @ ] 00007FFB4B87B42A00007FFB4B87B380 torch_cpu.dll!torch::jit::load [ @ ] 00007FF6089A737A00007FF6089A7210 pytroch_load_model.exe!main [c:\users\USER\projects\cmake dx cuda pytorch\cmake_integration_examples\pytorch\src\pytroch_load_model.cpp @ 19] 00007FF6089D8A9400007FF6089D8A60 pytroch_load_model.exe!invoke_main [d:\agent\_work\2\s\src\vctools\crt\vcstartup\src\startup\exe_common.inl @ 79] 00007FF6089D893E00007FF6089D8810 pytroch_load_model.exe!__scrt_common_main_seh [d:\agent\_work\2\s\src\vctools\crt\vcstartup\src\startup\exe_common.inl @ 288] 00007FF6089D87FE00007FF6089D87F0 pytroch_load_model.exe!__scrt_common_main [d:\agent\_work\2\s\src\vctools\crt\vcstartup\src\startup\exe_common.inl @ 331] 00007FF6089D8B2900007FF6089D8B20 pytroch_load_model.exe!mainCRTStartup [d:\agent\_work\2\s\src\vctools\crt\vcstartup\src\startup\exe_main.cpp @ 17] 00007FFBDF8C703400007FFBDF8C7020 KERNEL32.DLL!BaseThreadInitThunk [ @ ] 00007FFBDFBA265100007FFBDFBA2630 ntdll.dll!RtlUserThreadStart [ @ ]

以下是代码:

Python (PyTorch):

import torch import torch.nn as nn import torch.nn.functional as F class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x class TestModel(nn.Module): def __init__(self): super(TestModel, self).__init__() self.x = 2 def forward(self): return self.x test_net = torch.jit.script(Net()) test_module = torch.jit.script(TestModel()) torch.jit.save(test_net, 'test_net.pt') torch.jit.save(test_module, 'test_module.pt')

C++ (激光手电筒)

#include #include #include int main(int argc, const char* argv[]) { if (argc != 2) { std::cerr


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3